#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Example configure command (run from a build directory): cmake .. -DGPU_ARCHS=87

# Require CMake 3.13 or newer and declare the project with the C++ and CUDA
# languages enabled.
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(tensorrt_plugin LANGUAGES CXX CUDA)
# Presumably defines the set_ifndef() helper used below -- confirm in the module.
include(cmake/modules/set_ifndef.cmake)

# TRT_LIB_DIR and TRT_OUT_DIR are given the top-level build directory here.
# NOTE(review): judging by its name, set_ifndef() only assigns when the variable
# is not already defined (e.g. via -D on the command line) -- confirm.
set_ifndef(TRT_LIB_DIR ${CMAKE_BINARY_DIR})
set_ifndef(TRT_OUT_DIR ${CMAKE_BINARY_DIR})


# Read the TensorRT version from the installed NvInferVersion.h.
# VERSION_STRINGS receives every line of the header that defines an
# NV_TENSORRT_* macro.
file(STRINGS "/usr/include/aarch64-linux-gnu/NvInferVersion.h" VERSION_STRINGS REGEX "#define NV_TENSORRT_.*")

# Extract TRT_MAJOR, TRT_MINOR, TRT_PATCH and TRT_BUILD.
# "[0-9]+" captures the complete number, so multi-digit components (for example
# a major version of 10) are not truncated to their first digit.
# The inputs are quoted so that a missing macro yields an empty result that is
# reported below, instead of an argument-count error from string().
foreach(TYPE MAJOR MINOR PATCH BUILD)
    string(REGEX MATCH "NV_TENSORRT_${TYPE} [0-9]+" TRT_TYPE_STRING "${VERSION_STRINGS}")
    string(REGEX MATCH "[0-9]+" TRT_${TYPE} "${TRT_TYPE_STRING}")
    if("${TRT_${TYPE}}" STREQUAL "")
        message(FATAL_ERROR
            "Could not find a numeric NV_TENSORRT_${TYPE} definition in "
            "/usr/include/aarch64-linux-gnu/NvInferVersion.h")
    endif()
endforeach()

# Extract TRT_SO_MAJOR, TRT_SO_MINOR and TRT_SO_PATCH (shared-object version)
# in the same way.
foreach(TYPE MAJOR MINOR PATCH)
    string(REGEX MATCH "NV_TENSORRT_SONAME_${TYPE} [0-9]+" TRT_TYPE_STRING "${VERSION_STRINGS}")
    string(REGEX MATCH "[0-9]+" TRT_SO_${TYPE} "${TRT_TYPE_STRING}")
    if("${TRT_SO_${TYPE}}" STREQUAL "")
        message(FATAL_ERROR
            "Could not find a numeric NV_TENSORRT_SONAME_${TYPE} definition in "
            "/usr/include/aarch64-linux-gnu/NvInferVersion.h")
    endif()
endforeach()

# Both values are cache entries: a value already present in the cache (from -D
# or from an earlier configure run) is kept and the parsed value is ignored.
set(TRT_VERSION "${TRT_MAJOR}.${TRT_MINOR}.${TRT_PATCH}" CACHE STRING "TensorRT project version")
set(TRT_SOVERSION "${TRT_SO_MAJOR}" CACHE STRING "TensorRT library so version")
message("Building for TensorRT version: ${TRT_VERSION}, library version: ${TRT_SOVERSION}")

# Build all C++ sources as strict ISO C++17 (no compiler-specific extensions).
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_STANDARD 17)

# Wrap whatever CMAKE_CXX_FLAGS already holds with the project-wide options:
# the deprecation-warning switch goes in front, everything else is appended.
# NOTE(review): the trailing -w normally silences all compiler warnings, which
# would make -Wall and -Wno-deprecated-declarations ineffective -- confirm this
# is intended.
string(PREPEND CMAKE_CXX_FLAGS "-Wno-deprecated-declarations ")
string(APPEND CMAKE_CXX_FLAGS " -DBUILD_SYSTEM=cmake_oss")
string(APPEND CMAKE_CXX_FLAGS " -Wall -O3 -Wfatal-errors -w")
message("CXX_FLAGS: ${CMAKE_CXX_FLAGS}")

# Extra nvcc options: forward -Wno-deprecated-declarations to the host compiler
# and pass --expt-relaxed-constexpr to nvcc.
string(APPEND CMAKE_CUDA_FLAGS " -Xcompiler -Wno-deprecated-declarations --expt-relaxed-constexpr")


############################################################################################
# Cross-compilation settings

# Target platform identifier, with "aarch64" as the value passed to set_ifndef().
# It is read again further down when the default GPU architectures are chosen.
set_ifndef(TRT_PLATFORM_ID "aarch64")
message(STATUS "Targeting TRT Platform: ${TRT_PLATFORM_ID}")

############################################################################################
# Dependencies

# Fallback dependency versions handed to set_ifndef() below.
set(DEFAULT_CUDA_VERSION     "11.4")
set(DEFAULT_CUDNN_VERSION    "8.6")
set(DEFAULT_PROTOBUF_VERSION "3.6.1")


# Dependency version resolution: one set_ifndef() call per dependency, followed
# by a status line reporting the version in effect.
set_ifndef(CUDA_VERSION "${DEFAULT_CUDA_VERSION}")
message(STATUS "CUDA version set to ${CUDA_VERSION}")

set_ifndef(CUDNN_VERSION "${DEFAULT_CUDNN_VERSION}")
message(STATUS "cuDNN version set to ${CUDNN_VERSION}")

set_ifndef(PROTOBUF_VERSION "${DEFAULT_PROTOBUF_VERSION}")
message(STATUS "Protobuf version set to ${PROTOBUF_VERSION}")

# Locate the CUDA toolkit; REQUIRED makes configuration stop if it is not found.
find_package(CUDA ${CUDA_VERSION} REQUIRED)
message("CUDA_INCLUDE_DIRS: ${CUDA_INCLUDE_DIRS}")

# Directory-scoped library search path for the aarch64 system libraries.
link_directories("/usr/lib/aarch64-linux-gnu")

# Directory-scoped header search paths: the CUDA headers plus the aarch64 system
# include directory (the same directory NvInferVersion.h is read from).
include_directories(${CUDA_INCLUDE_DIRS} "/usr/include/aarch64-linux-gnu")


# Search hints for the CUDA runtime libraries, in priority order:
#   1. the lib64 directory of the toolkit located by find_package(CUDA),
#   2. the conventional install prefix for the CUDA version in use,
#   3. the CUDA 11.4 location this file originally hard-coded.
# Trying the detected toolkit first keeps the libraries consistent with
# CUDA_INCLUDE_DIRS when CUDA_VERSION is not 11.4.
set(_trt_cuda_lib_hints "")
if(CUDA_TOOLKIT_ROOT_DIR)
    list(APPEND _trt_cuda_lib_hints "${CUDA_TOOLKIT_ROOT_DIR}/lib64")
endif()
if(NOT "${CUDA_VERSION}" STREQUAL "")
    list(APPEND _trt_cuda_lib_hints "/usr/local/cuda-${CUDA_VERSION}/lib64")
endif()
list(APPEND _trt_cuda_lib_hints "/usr/local/cuda-11.4/lib64")

# find_library() stores each result in the cache, so a path supplied with -D or
# found by an earlier configure run is reused as-is.
find_library(CUDNN_LIB cudnn HINTS "/usr/lib/aarch64-linux-gnu")
find_library(CUBLAS_LIB cublas HINTS ${_trt_cuda_lib_hints})
find_library(CUBLASLT_LIB cublasLt HINTS ${_trt_cuda_lib_hints})
find_library(CUDART_LIB cudart HINTS ${_trt_cuda_lib_hints})


message("cudnn lib dir: ${CUDNN_LIB}")
message("CUBLAS_LIB lib dir: ${CUBLAS_LIB}")
message("CUBLASLT_LIB lib dir: ${CUBLASLT_LIB}")
message("CUDART_LIB lib dir: ${CUDART_LIB}")

find_library(RT_LIB rt)

message("CUDA_TOOLKIT_ROOT_DIR: ${CUDA_TOOLKIT_ROOT_DIR}")

# Report libraries that were not found. This is a warning rather than an error
# because this file cannot tell which of them the plugin targets actually link.
foreach(_trt_lib_var CUDNN_LIB CUBLAS_LIB CUBLASLT_LIB CUDART_LIB RT_LIB)
    if(NOT ${_trt_lib_var})
        message(WARNING "${_trt_lib_var} was not found; targets that link against it will fail to build.")
    endif()
endforeach()
unset(_trt_lib_var)
unset(_trt_cuda_lib_hints)

############################################################################################
# CUDA targets

# Decide which SM versions CUDA code is generated for. A caller-supplied
# GPU_ARCHS (e.g. -DGPU_ARCHS=87) wins; otherwise a default list is assembled
# from the target platform and the CUDA version.
if (NOT DEFINED GPU_ARCHS)
  # IS_ARM is non-empty only when the platform id contains "aarch64".
  string(REGEX MATCH "aarch64" IS_ARM "${TRT_PLATFORM_ID}")
  if (IS_ARM)
    # Xavier (SM72) is only added for aarch64 targets.
    list(APPEND GPU_ARCHS 72)
  endif()

  # Ampere (SM80) is added for CUDA 11.0 and newer.
  if (CUDA_VERSION VERSION_GREATER_EQUAL 11.0)
    list(APPEND GPU_ARCHS 80)
  endif()

  # SM86 and SM87 are added for CUDA 11.1 and newer.
  if (CUDA_VERSION VERSION_GREATER_EQUAL 11.1)
    list(APPEND GPU_ARCHS 86 87)
  endif()

  message(STATUS "GPU_ARCHS is not defined. Generating CUDA code for default SMs: ${GPU_ARCHS}")
else()
  message(STATUS "GPU_ARCHS defined as ${GPU_ARCHS}. Generating CUDA code for SM ${GPU_ARCHS}")
  # Turn a space-separated value such as "80 86" into a CMake list.
  separate_arguments(GPU_ARCHS)
endif()


# Process plugins/CMakeLists.txt; the variables set above are visible in that
# directory scope.
add_subdirectory(plugins)


